import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class GazeDataset(Dataset):
    """Image-classification dataset laid out as ``data_dir/<label>/<image>``.

    Every sub-directory of ``data_dir`` is one class and its name is the
    label. Labels get consecutive integer indices in sorted name order, and
    every regular file inside a label directory is one sample.

    Args:
        data_dir: Root directory whose sub-directories are the classes.
        transform: Optional callable applied to each loaded PIL image in
            ``__getitem__``.

    Attributes:
        image_paths: Path of each sample, grouped by class in sorted order.
        labels: Integer class index of each sample (parallel to image_paths).
        label_to_index: Maps directory name -> integer index.
        index_to_label: Inverse of ``label_to_index``.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.label_to_index = {}
        self.index_to_label = {}

        # Only directories are classes. Filter *before* enumerating so the
        # indices stay contiguous (0..C-1): a stray file such as ".DS_Store"
        # in the root must not consume an index and shift every later label.
        label_dirs = sorted(
            name
            for name in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, name))
        )
        for idx, label_dir in enumerate(label_dirs):
            self.label_to_index[label_dir] = idx
            self.index_to_label[idx] = label_dir
            label_path = os.path.join(data_dir, label_dir)
            # os.listdir order is arbitrary; sort so sample order is
            # reproducible. Skip nested directories, which Image.open
            # cannot read.
            for img_name in sorted(os.listdir(label_path)):
                img_path = os.path.join(label_path, img_name)
                if os.path.isfile(img_path):
                    self.image_paths.append(img_path)
                    self.labels.append(idx)  # store the integer label
        # Print the label -> index mapping for reference.
        print("Label to Index Mapping:", self.label_to_index)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the loaded PIL image, passed through ``transform`` when
        one is set; ``label`` is the integer class index.
        """
        # Image.open is lazy and keeps the file handle open until the pixels
        # are read. copy() loads the data into memory so the handle is closed
        # on leaving the with-block. Without this, nothing would load or
        # close the file when transform is None.
        with Image.open(self.image_paths[idx]) as img:
            image = img.copy()
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Shared preprocessing for every split: resize to 64x64, convert to a float
# tensor (ToTensor scales 8-bit images to [0, 1]), then apply
# (x - 0.5) / 0.5, mapping [0, 1] onto [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Datasets are built in train -> test -> val order; each prints its mapping.
train_dataset = GazeDataset(data_dir='data/train', transform=transform)
test_dataset = GazeDataset(data_dir='data/test', transform=transform)
# NOTE(review): val_dataset reads 'data/test', the same directory as
# test_dataset, so validation and test samples are identical. Confirm this
# is intended, or point it at a separate validation split.
val_dataset = GazeDataset(data_dir='data/test', transform=transform)

# Batches of 32; only the training loader reshuffles each epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

